import cdtools
from cdtools.tools import plotting as p
from matplotlib import pyplot as plt
from matplotlib import animation
import matplotlib.ticker as ticker
from scipy.io import loadmat, savemat
import numpy as np
import torch as t
from helper_functions import compare_ptycho_to_rpi, compare_rpi_to_rpi, hann_window,  get_circular_lineout, fourier_pad

# Give figures a transparent background when saved (e.g. to PDF) while keeping
# the plotting area inside each axes opaque white.
_TRANSPARENT = (1.0, 1.0, 1.0, 0.0)
_OPAQUE_WHITE = (1.0, 1.0, 1.0, 1.0)
plt.rcParams.update({
    "figure.facecolor": _TRANSPARENT,
    "axes.facecolor": _OPAQUE_WHITE,
    "savefig.facecolor": _TRANSPARENT,
})

#
# A few configuration parameters for the script
#

# If True, will save the raw image output for all the imshows (PNG), as well
# as the versions embedded in axes (PDF), into the "figs" directory. The raw
# images are used for the paper figures for image data.
save_output = False
if save_output:
    from pathlib import Path
    Path('figs').mkdir(exist_ok=True)

# Index of the first exposure to analyze
start = 0

# The number of consecutive exposures (starting at `start`) to analyze
n_exposures = 128

# Which cumulative averages to plot, given as the number of shots averaged
# (e.g., an average of 1 image, one of 2 images, ...). Every entry must lie
# between 1 and n_exposures.
plot_sums = [1, 2, 4, 12, 64]

# Fail fast on an inconsistent configuration, rather than with an IndexError
# after all of the (slow) RPI-vs-ptycho comparisons have already been run.
if n_exposures < 1:
    raise ValueError(f'n_exposures must be at least 1, got {n_exposures}')
_bad_sums = [n for n in plot_sums if not 1 <= n <= n_exposures]
if _bad_sums:
    raise ValueError(
        f'plot_sums entries must be between 1 and n_exposures={n_exposures}; '
        f'invalid entries: {_bad_sums}')


# Names of the raw datasets; the results files are derived from these:
ptycho_filename = 'single_shot_ptycho.cxi'
rpi_filename = 'single_shot_rpi.cxi'

def make_window(center, radius):
    """Return a 2D index selecting the square region of half-width *radius*.

    The result is a tuple of two ``slice`` objects: rows from
    ``center[0] - radius`` up to (not including) ``center[0] + radius``,
    and likewise for columns around ``center[1]``. It can be used directly
    as ``array[window]``.
    """
    row_center, col_center = center[0], center[1]
    row_span = slice(row_center - radius, row_center + radius)
    col_span = slice(col_center - radius, col_center + radius)
    return row_span, col_span


#
# Load the ptychography and RPI results
#


def _results_stem(filename):
    """Drop the last '.'-separated component of *filename* ('a.b.cxi' -> 'a.b')."""
    return '.'.join(filename.split('.')[:-1])


# Fourier-space padding used for the ptycho reconstruction; the results
# file name encodes it.
padding = 200
ptycho_results = loadmat(
    f'results/{_results_stem(ptycho_filename)}'
    f'_synthesized_polished_{padding:d}pad.mat')

# Center of the RPI field of view, in ptycho-reconstruction pixel
# coordinates. The offsets look like manual adjustments to a nominal
# center (TODO confirm).
ptycho_center = (1193 - 23, 1160 + 35)
# Half-height / half-width of the region extracted from the ptycho
# reconstruction
ptycho_radius = 500
ptycho_window = make_window(ptycho_center, ptycho_radius)

# Total width of the low-resolution object representation used for the RPI
# reconstructions; the results file name encodes it.
resolution = 650
rpi_results = loadmat(
    f'results/{_results_stem(rpi_filename)}_2modes_{resolution:d}.mat')


#
# Define a helper function to save image data
#

def save_latest_im(filename):
    """Save the image data shown in the current axes, without axes or labels.

    Writes the array of the first image on ``plt.gca()`` to *filename* with
    ``plt.imsave``. The image's own colormap, color limits and origin are
    passed along, so the saved file matches what is displayed (for example,
    after a ``plt.clim`` call). For RGB(A) data, matplotlib ignores the
    colormap and limits.

    Raises:
        ValueError: if the current axes contain no images.
    """
    images = plt.gca().get_images()
    if not images:
        raise ValueError('The current axes contain no images to save')
    im = images[0]
    # Without these, imsave would use the default colormap, autoscale the
    # color range, and use the default origin. The saved "raw" image would
    # then disagree with the displayed one.
    vmin, vmax = im.get_clim()
    plt.imsave(filename, im.get_array(), cmap=im.get_cmap(),
               vmin=vmin, vmax=vmax, origin=getattr(im, 'origin', None))



#
# Load some RPI data to show a raw diffraction pattern
#

# Detector calibration, used to convert detector counts (ADU) to photons.
# Electrons generated per detected 60 eV photon (presumably using ~3.6 eV
# per electron-hole pair in silicon; confirm).
electronsperphoton = 60 / 3.6
# Assumes the 16-bit ADC range (2**16 ADU) spans the 200,000-electron
# full-well capacity (from the MTE-2048B datasheet).
ADUperelectron = 2**16 / 200000
ADUperphoton = ADUperelectron * electronsperphoton

print('ADU per electron', ADUperelectron)
print('ADU per photon', ADUperphoton)

rpi_dataset = cdtools.datasets.Ptycho2DDataset.from_cxi('data/' + rpi_filename)

# Show the first analyzed pattern, in photons, on a log scale. Negative
# values are clamped to 0 so that log10(pattern + 1) is non-negative.
example_pattern = rpi_dataset.patterns[start] / ADUperphoton
example_pattern = t.clamp(example_pattern, min=0)

p.plot_real(t.log10(example_pattern + 1))

if save_output:
    save_latest_im('figs/RPI_example_pattern.png')
    plt.savefig('figs/RPI_example_pattern.pdf')


# Photon statistics over the same exposures that are analyzed below
# (start .. start + n_exposures), rather than a hard-coded first 128 patterns.
total_photons = t.sum(rpi_dataset.patterns[start:start + n_exposures],
                      dim=(1, 2)) / ADUperphoton
mean_photons = t.mean(total_photons)
print('Mean Photons Per Shot:', mean_photons)
# NOTE(review): assumes a 1024 x 1024 detector; confirm against
# rpi_dataset.patterns.shape.
mean_photons_per_det_pixel = mean_photons / (1024**2)
# 24672 is the space-bandwidth product (number of resolution elements);
# TODO document how it was derived.
mean_photons_per_res_element = mean_photons / 24672
print('Mean Photons Per Pixel:', mean_photons_per_det_pixel)
print('Mean Photons Per Resolution Element:', mean_photons_per_res_element)
print('RPI Number of Pixels:', rpi_results['objs'][0].shape[-1])
print('RPI Pixel Size:', np.abs(rpi_results['basis'][0,1]))
print('Ptycho Pixel Size:', np.abs(ptycho_results['basis'][0,1]))

#
# Now I estimate the illumination fluence on the sample
#

quantum_efficiency = 0.42  # From PI-MTE 2048 Datasheet, at 60 eV.
energy_per_photon = 60 * 1.6e-19  # 60 eV photons, in Joules
# These two are used as *transmitted* fractions: the energy at the detector
# is divided by them to recover the energy incident on the sample.
membrane_attenuation = 0.55  # calculated for 30 nm of SiN at 60 eV
siemens_star_attenuation = 0.35  # estimated from the ptycho result

# Energy arriving at the detector: detected photons / quantum efficiency,
# times the energy per photon
total_energy_at_det = mean_photons / (quantum_efficiency) * energy_per_photon
print('Total energy at Detector', 1e9 * total_energy_at_det, 'nJ')

energy_incident_on_sample = total_energy_at_det / (membrane_attenuation * siemens_star_attenuation)
print('Energy incident on sample', 1e9 * energy_incident_on_sample, 'nJ')

# The incident energy is treated as uniformly spread over a disk of this
# radius on the sample (assumed illumination size; confirm).
beam_radius = 20e-6  # in m
# Energy per unit area, i.e. a fluence in J/m^2 (despite the variable name)
flux = energy_incident_on_sample / (np.pi * beam_radius**2)
# 1 J/m^2 = 1e-4 J/cm^2 = 100 uJ/cm^2
print('Fluence incident on sample', 1e2 * flux, 'uJ/cm^2')


# Total measured photons for each analyzed shot
plt.figure(figsize=(3.75, 2.5))
plt.plot(total_photons)
plt.xlabel('Shot Index')
plt.ylabel('Measured Photons')
plt.tight_layout()

if save_output:
    plt.savefig('figs/Photons_per_shot.pdf')


#
# Normalize the RPI results
#

# Every single-shot RPI reconstruction is compared against the ptychography
# result. The comparison also centers each RPI result to a common location
# ('shifted_rpi_obj'), which is what makes averaging across shots possible.

single_shot_comparisons = []
summed_ims = []
running_total = None

for shot in range(start, start + n_exposures):
    print(f'Comparing image {shot} with ptycho')
    shot_result = {'obj': rpi_results['objs'][shot],
                   'basis': rpi_results['basis']}
    shot_comparison = compare_ptycho_to_rpi(ptycho_results, shot_result,
                                            ptycho_window)
    single_shot_comparisons.append(shot_comparison)

    # Cumulative sum of the aligned objects: entry k holds the first k+1 shots
    aligned = shot_comparison['shifted_rpi_obj']
    if running_total is None:
        running_total = aligned
    else:
        running_total = running_total + aligned
    summed_ims.append(running_total)

# Turn the cumulative sums into cumulative means
summed_ims = [total / n_summed
              for n_summed, total in enumerate(summed_ims, start=1)]

#
# Calculate resolution metrics for all the summed images
#

summed_comparisons = []
for n_summed, mean_im in enumerate(summed_ims, start=1):
    print(f'Comparing sum of {n_summed} image{"" if n_summed == 1 else "s"} with ptycho')
    mean_result = {'obj': mean_im, 'basis': rpi_results['basis']}
    summed_comparisons.append(
        compare_ptycho_to_rpi(ptycho_results, mean_result, ptycho_window))


#
# Now we plot an example single-shot RPI result
#

# Number of pixels trimmed from each edge of the raw RPI object before
# plotting. Slicing up to `shape - crop` (rather than `-crop`) keeps
# crop = 0 from producing an empty image.
crop = 120
# Use the first analyzed exposure, matching the example diffraction pattern
example_obj = rpi_results['objs'][start]
example_obj = example_obj[crop:example_obj.shape[0] - crop,
                          crop:example_obj.shape[1] - crop]

# (file-name suffix, plotting function, whether to fix the color limits to
# a full [-pi, pi] phase range)
for view_name, plot_view, wrap_phase in (
        ('colorized', p.plot_colorized, False),
        ('amplitude', p.plot_amplitude, False),
        ('phase', p.plot_phase, True)):
    f = plt.figure(dpi=300)
    plot_view(example_obj, basis=rpi_results['basis'], fig=f)
    if wrap_phase:
        plt.clim([-np.pi, np.pi])
    if save_output:
        save_latest_im(f'figs/RPI_raw_{view_name}.png')
        plt.savefig(f'figs/RPI_raw_{view_name}.pdf')


#
# And now we plot colorized, amplitude, and phase images of the summed images
#

# (file-name suffix, plotting function, whether to fix the color limits to
# a full [-pi, pi] phase range)
_summed_views = (
    ('colorized', p.plot_colorized, False),
    ('amplitude', p.plot_amplitude, False),
    ('phase', p.plot_phase, True),
)

for n_shots in plot_sums:
    # summed_comparisons[k] holds the average of the first k+1 shots
    averaged_obj = summed_comparisons[n_shots - 1]['rpi_obj']
    for view_name, plot_view, wrap_phase in _summed_views:
        fig = plt.figure(dpi=300)
        plot_view(averaged_obj, basis=rpi_results['basis'], fig=fig)
        if wrap_phase:
            plt.clim([-np.pi, np.pi])
        if save_output:
            save_latest_im(f'figs/RPI_sum_{n_shots}_{view_name}.png')
            plt.savefig(f'figs/RPI_sum_{n_shots}_{view_name}.pdf')


#
# Now we plot the FRCs and SSNRs
#

# All cumulative-average curves stacked into arrays. The plots below do not
# use these; they are kept for interactive inspection.
frcs = np.array([comp['frc'] for comp in summed_comparisons])
ssnrs = np.array([comp['ssnr'] for comp in summed_comparisons])

frc_fig = plt.figure(figsize=(3.75, 2.5))
ssnr_fig = plt.figure(figsize=(3.75, 2.5))
max_ssnrs = []
R_1_ssnrs = []
epss = []

# First we plot the greyed-out single-shot comparisons
for comp in single_shot_comparisons:
    freqs_um = comp['frc_freqs'] * 1e-6
    plt.figure(frc_fig)
    plt.plot(freqs_um, comp['frc'], 'grey')
    plt.figure(ssnr_fig)
    plt.semilogy(freqs_um, comp['ssnr'], 'grey')


# Summary metrics for every cumulative average (1 shot, 2 shots, ...)
for comp in summed_comparisons:
    # The first two frequency bins are excluded from the maximum
    max_ssnrs.append(np.max(comp['ssnr'][2:]))
    epss.append(comp['eps'])
    # Index of the first frequency bin above 2.5 cycles / um
    R_1_freq_idx = np.nonzero(comp['frc_freqs'] > 2.5e6)[0][0]
    # TODO: Is this actually where R=1??
    R_1_ssnrs.append(comp['ssnr'][R_1_freq_idx])


# Then we overlay the highlighted curves for the selected averages
for i in plot_sums:
    comp = summed_comparisons[i-1]

    # `i` counts shots starting from 1, so the singular form is for i == 1
    label = f'{i} shot{"" if i == 1 else "s"}'
    freqs_um = comp['frc_freqs'] * 1e-6

    plt.figure(frc_fig)
    plt.plot(freqs_um, comp['frc'], label=label)
    plt.figure(ssnr_fig)
    plt.semilogy(freqs_um, comp['ssnr'], label=label)

# Comparison whose frequency axis sets the threshold line and x-limits:
# the last highlighted average (or the last average if none are highlighted)
ref_comp = (summed_comparisons[plot_sums[-1] - 1] if plot_sums
            else summed_comparisons[-1])

plt.figure(frc_fig)
plt.grid(True)
plt.xlabel('Full Pitch Spatial Frequency (cycles / um)')
plt.ylabel('FRC (ptycho vs RPI)')
plt.legend()
plt.tight_layout()

plt.figure(ssnr_fig)
plt.semilogy(ptycho_results['frc_freqs'][0] * 1e-6, ptycho_results['ssnr'][0],
             'k', label='ptycho')
plt.grid(True, which='both')
plt.xlabel('Spatial Frequency (cycles / um)')
plt.ylabel('SSNR')
plt.legend()
# Dashed threshold at SSNR = 0.4142 (~ sqrt(2) - 1; presumably the SSNR
# equivalent of the half-bit criterion; confirm)
plt.plot(ref_comp['frc_freqs'] * 1e-6, 0.4142 + 0 * ref_comp['ssnr'], 'k--')
plt.xlim([0, np.max(ref_comp['frc_freqs']) * 1e-6])
plt.tight_layout()

# Number of shots in each cumulative average: 1, 2, ..., n_exposures
shot_counts = np.arange(1, len(summed_comparisons) + 1)

plt.figure(figsize=(3.75, 2.5))
plt.plot(shot_counts, max_ssnrs)
plt.grid(True)
plt.xlabel('Number of shots')
plt.ylabel('Maximum nonzero SSNR value')
plt.tight_layout()

plt.figure(figsize=(3.75, 2.5))
plt.semilogy(shot_counts, epss)
plt.grid(True)
plt.xlabel('Number of shots')
plt.ylabel('Normalized MSE')
plt.tight_layout()

plt.figure(figsize=(3.75, 2.5))
plt.plot(shot_counts, R_1_ssnrs)
plt.grid(True)
plt.xlabel('Number of shots')
plt.ylabel('SSNR at R=1')
plt.tight_layout()


#
# Now we plot the explicit line-cuts
#


# Step 1: We define the images to use for the linecuts, giving the images,
# their pixel size, and the center of the Siemens star in each image.

# Start index of the region cropped out of the full ptycho object; the
# Siemens-star center is expressed relative to this crop.
ptycho_crop = 750
rpi_siemens_center = [221, 194.75]
ptycho_siemens_center = [1192 - ptycho_crop, 1159.5 - ptycho_crop]

# Number of shots in the multi-shot RPI average used for the linecut. It is
# clamped to the number of available averages, and the label is built from
# the same value so that the legend always matches the plotted data.
n_linecut_shots = min(64, len(summed_comparisons))

linecut_ims = [
    {
        'label': 'Ptychography',
        'im': ptycho_results['obj_full'][ptycho_crop:1600, ptycho_crop:1600],
        'center': ptycho_siemens_center,
        'pix_size': np.abs(ptycho_results['basis'][0,1]),
        'color': 'black',
    },
    {
        'label': f'{n_linecut_shots}-Shot RPI',
        'im': summed_comparisons[n_linecut_shots - 1]['rpi_obj'],
        'center': rpi_siemens_center,
        'pix_size': np.abs(rpi_results['basis'][0,1]),
        'color': 'tab:purple',
    },
    {
        'label': '1-Shot RPI',
        'im': summed_comparisons[0]['rpi_obj'],
        'center': rpi_siemens_center,
        'pix_size': np.abs(rpi_results['basis'][0,1]),
        'color': 'tab:orange',
    },
]

# This sets the radius of the circular linecut: the circle's circumference
# is 64 * fp_resolution.
fp_resolution = 400e-9  # 400 nm

# We will sinc-upsample the images to the following pixel size for the linecut
upsampled_pix_size = 10e-9

linecut_fig_real = plt.figure(figsize=(12, 2.5))
linecut_fig_abs = plt.figure(figsize=(12, 2.5))
linecut_fig_phase = plt.figure(figsize=(12, 2.5))

for result in linecut_ims:

    # First, we plot the linecut location on the original (non-upsampled) image
    circumference = (fp_resolution / result['pix_size']) * 64  # in pixels
    radius = circumference / (2 * np.pi)

    thetas, values, points = get_circular_lineout(
        result['im'], result['center'], radius)
    xs = thetas * radius * result['pix_size']  # arc length, in m

    plt.figure()
    plt.imshow(np.abs(result['im']))
    plt.plot(points[:,1], points[:,0], 'k-', linewidth=1)
    # Highlight the first 2.5 um of arc length
    linecut_mask = xs <= 2.5e-6
    plt.plot(points[linecut_mask,1], points[linecut_mask,0],
             '-', color=result['color'], linewidth=2)


    # Then, we upsample the image and calculate the actual final linecut.
    # The padding formula assumes square images, which these are.
    upsampling_factor = result['pix_size'] / upsampled_pix_size
    # Named distinctly so it does not overwrite the module-level ptycho
    # `padding` used to locate the ptycho results file.
    upsample_padding = int((result['im'].shape[0] * (upsampling_factor - 1)) / 2)
    upsampled_im = fourier_pad(t.as_tensor(result['im']), upsample_padding).numpy()

    # We recalculate the upsampled pixel size based on what was actually
    # achieved, since the padding is rounded to an integer.
    achieved_upsampling_factor = upsampled_im.shape[0] / result['im'].shape[0]
    achieved_upsampled_pix_size = result['pix_size'] / achieved_upsampling_factor

    upsampled_circumference = (fp_resolution / achieved_upsampled_pix_size) * 64
    upsampled_radius = upsampled_circumference / (2 * np.pi)
    upsampled_center = [c * achieved_upsampling_factor for c in result['center']]
    thetas, values, points = get_circular_lineout(
        upsampled_im, upsampled_center, upsampled_radius)

    # Arc length along the upsampled circle, in m. It equals
    # thetas * radius * pix_size mathematically, but uses the quantities
    # these thetas were actually computed with.
    xs = thetas * upsampled_radius * achieved_upsampled_pix_size

    # Normalize the real part and amplitude by a robust maximum of |values|
    scale = np.quantile(np.abs(values), 0.95)

    plt.figure(linecut_fig_real)
    plt.plot(xs * 1e6, np.real(values) / scale,
             label=result['label'], color=result['color'])
    plt.figure(linecut_fig_abs)
    plt.plot(xs * 1e6, np.abs(values) / scale,
             label=result['label'], color=result['color'])
    plt.figure(linecut_fig_phase)
    plt.plot(xs * 1e6, np.angle(values),
             label=result['label'], color=result['color'])

plt.figure(linecut_fig_real)
plt.ylim([-0.25, 1.25])
plt.xlabel('Arc Length (um)')
plt.ylabel('Real Part (arb. units)')
plt.legend()
plt.tight_layout()

plt.figure(linecut_fig_abs)
plt.ylim([-0.25, 1.25])
plt.xlabel('Arc Length (um)')
plt.ylabel('Amplitude (arb. units)')
plt.legend()
plt.tight_layout()

plt.figure(linecut_fig_phase)
plt.ylim([-np.pi, np.pi])
plt.xlabel('Arc Length (um)')
plt.ylabel('Phase (rad)')
plt.legend()
plt.tight_layout()
plt.show()
